# File-level guards: these run once, at the top of the file, before any
# test_that() block below. All five helpers are defined outside this file;
# presumably each one skips the remaining tests in the environment its name
# refers to -- confirm against the helper definitions.
skip_connection("ml-classification-logistic-regression")
skip_on_livy()
skip_on_arrow_devel()
skip_on_ci()
skip_databricks_connect()
test_that("ml_logistic_regression() default params", {
  # The version gate runs before the connection is fetched.
  test_requires_version("3.0.0")
  spark_conn <- testthat_spark_connection()
  # The estimator constructor itself is handed to the shared checker
  # (test_default_args() is defined outside this file).
  test_default_args(spark_conn, ml_logistic_regression)
})

test_that("ml_logistic_regression() param setting", {
  test_requires_version("3.0.0")
  spark_conn <- testthat_spark_connection()

  # One explicit value per parameter; the list is passed as-is to
  # test_param_setting(), which is defined outside this file.
  param_values <- list(
    fit_intercept = FALSE,
    elastic_net_param = 1e-4,
    reg_param = 1e-5,
    max_iter = 50,
    # `threshold` is omitted: it can't seem to be set when `thresholds` is
    thresholds = c(0.3, 0.7),
    tol = 1e-4,
    weight_col = "wow",
    aggregation_depth = 3,
    # Bound parameters are disabled for now; we'll want to enable them,
    # see #1616
    # upper_bounds_on_coefficients = matrix(rep(1, 6), nrow = 3),
    # lower_bounds_on_coefficients = matrix(rep(-1, 6), nrow = 3),
    # upper_bounds_on_intercepts = c(1, 1, 1),
    # lower_bounds_on_intercepts = c(-1, -1, -1),
    features_col = "foo",
    label_col = "bar",
    family = "multinomial",
    prediction_col = "pppppp",
    probability_col = "apweiof",
    raw_prediction_col = "rparprpr"
  )

  test_param_setting(spark_conn, ml_logistic_regression, param_values)
})

test_that("ml_logistic_regression.tbl_spark() works properly", {
  sc <- testthat_spark_connection()

  # NOTE(review): `training` and `test` are never referenced by symbol below;
  # testthat_tbl() only receives their names, so it presumably resolves these
  # locals by name -- keep the variable names in sync with the strings.
  training <- tibble(
    id = 0:3,
    text = c("a b c d e spark", "b d", "spark f g h", "hadoop mapreduce"),
    label = c(1, 0, 1, 0)
  )
  test <- tibble(
    id = 4:7,
    text = c("spark i j k", "l m n", "spark hadoop spark", "apache hadoop")
  )
  training_tbl <- testthat_tbl("training")
  test_tbl <- testthat_tbl("test")

  # Route 1: tokenizer -> hashing TF -> logistic regression assembled as
  # pipeline stages, then fitted on the training table.
  stages <- ml_pipeline(sc)
  stages <- ft_tokenizer(stages, "text", "words")
  stages <- ft_hashing_tf(stages, "words", "features", num_features = 1000)
  stages <- ml_logistic_regression(stages, max_iter = 10, reg_param = 0.001)
  pipeline_model <- ml_fit(stages, training_tbl)

  # The assignment stays inside the expectation so the whole scoring step is
  # evaluated under expect_warning_on_arrow().
  expect_warning_on_arrow(
    pipeline_probs <- pull(
      ml_transform(pipeline_model, test_tbl),
      probability
    )
  )

  # Route 2: the same feature transformers applied directly to the tables,
  # followed by a direct model fit.
  tokenized_training <- ft_tokenizer(training_tbl, "text", "words")
  hashed_training <- ft_hashing_tf(
    tokenized_training, "words", "features",
    num_features = 1000
  )
  direct_model <- ml_logistic_regression(
    hashed_training,
    max_iter = 10, reg_param = 0.001
  )

  expect_warning_on_arrow(
    direct_probs <- pull(
      ml_transform(
        direct_model,
        ft_hashing_tf(
          ft_tokenizer(test_tbl, "text", "words"),
          "words", "features",
          num_features = 1000
        )
      ),
      probability
    )
  )

  # Both routes must yield the same probability column.
  expect_equal(pipeline_probs, direct_probs)
})

test_that("ml_logistic_regression() agrees with stats::glm()", {
  sc <- testthat_spark_connection()
  set.seed(42)
  # iris plus: Poisson-derived weights shifted by one, a constant weight
  # column, and a 0/1 versicolor indicator. The local is only handed to
  # testthat_tbl() by name, so its name must match the string below.
  iris_weighted <- mutate(
    iris,
    weights = rpois(nrow(iris), 1) + 1,
    ones = rep(1, nrow(iris)),
    versicolor = ifelse(Species == "versicolor", 1L, 0L)
  )
  iris_weighted_tbl <- testthat_tbl("iris_weighted")

  # The glm formula uses dotted column names; the Spark formula uses
  # underscored ones.
  r_formula <- versicolor ~ Sepal.Width + Petal.Length + Petal.Width
  spark_formula <- "versicolor ~ Sepal_Width + Petal_Length + Petal_Width"

  # Case 1: observation weights taken from the `weights` column.
  glm_weighted <- glm(
    r_formula,
    family = binomial(logit),
    weights = weights,
    data = iris_weighted
  )
  spark_weighted <- ml_logistic_regression(
    iris_weighted_tbl,
    formula = spark_formula,
    reg_param = 0L,
    weight_col = "weights"
  )
  expect_equal(
    unname(coef(glm_weighted)),
    unname(coef(spark_weighted)),
    tolerance = 1e-5,
    scale = 1
  )

  # Case 2: unweighted glm against Spark with the all-ones weight column.
  glm_unweighted <- glm(
    r_formula,
    family = binomial(logit),
    data = iris_weighted
  )
  spark_unit_weights <- ml_logistic_regression(
    iris_weighted_tbl,
    formula = spark_formula,
    reg_param = 0L,
    weight_col = "ones"
  )
  expect_equal(
    unname(coef(glm_unweighted)),
    unname(coef(spark_unit_weights)),
    tolerance = 1e-5,
    scale = 1
  )
})

test_that("ml_logistic_regression can fit without intercept", {
  sc <- testthat_spark_connection()
  set.seed(42)
  # The local is only handed to testthat_tbl() by name, so its name must match
  # the string below.
  iris_weighted <- mutate(
    iris,
    weights = rpois(nrow(iris), 1) + 1,
    ones = rep(1, nrow(iris)),
    versicolor = ifelse(Species == "versicolor", 1L, 0L)
  )
  iris_weighted_tbl <- testthat_tbl("iris_weighted")

  # An NA pattern makes expect_error() assert that no error is raised. The fit
  # is assigned inside the expectation so it can be compared below.
  expect_error(
    spark_fit <- ml_logistic_regression(
      iris_weighted_tbl,
      formula = versicolor ~ Sepal_Width + Petal_Length + Petal_Width,
      fit_intercept = FALSE
    ),
    NA
  )

  # The trailing `- 1` removes the intercept from the reference glm fit.
  glm_fit <- glm(
    versicolor ~ Sepal.Width + Petal.Length + Petal.Width - 1,
    family = binomial(logit),
    data = iris_weighted
  )
  expect_equal(
    unname(coef(glm_fit)),
    unname(coef(spark_fit)),
    tolerance = 1e-5,
    scale = 1
  )
})

test_that("ml_logistic_regression() agrees with stats::glm() for reversed categories", {
  sc <- testthat_spark_connection()
  set.seed(42)
  # NOTE(review): despite the test name, the indicator below is coded exactly
  # as in the plain glm-agreement test (versicolor = 1, everything else = 0);
  # nothing is reversed here -- confirm whether a flipped encoding was meant.
  # The local is only handed to testthat_tbl() by name, so its name must match
  # the string below.
  iris_weighted <- mutate(
    iris,
    weights = rpois(nrow(iris), 1) + 1,
    ones = rep(1, nrow(iris)),
    versicolor = ifelse(Species == "versicolor", 1L, 0L)
  )
  iris_weighted_tbl <- testthat_tbl("iris_weighted")

  # The glm formula uses dotted column names; the Spark formula uses
  # underscored ones.
  r_formula <- versicolor ~ Sepal.Width + Petal.Length + Petal.Width
  spark_formula <- "versicolor ~ Sepal_Width + Petal_Length + Petal_Width"

  # Case 1: observation weights taken from the `weights` column.
  glm_weighted <- glm(
    r_formula,
    family = binomial(logit),
    weights = weights,
    data = iris_weighted
  )
  spark_weighted <- ml_logistic_regression(
    iris_weighted_tbl,
    formula = spark_formula,
    reg_param = 0L,
    weight_col = "weights"
  )
  expect_equal(
    unname(coef(glm_weighted)),
    unname(coef(spark_weighted)),
    tolerance = 1e-5,
    scale = 1
  )

  # Case 2: unweighted glm against Spark with the all-ones weight column.
  glm_unweighted <- glm(
    r_formula,
    family = binomial(logit),
    data = iris_weighted
  )
  spark_unit_weights <- ml_logistic_regression(
    iris_weighted_tbl,
    formula = spark_formula,
    reg_param = 0L,
    weight_col = "ones"
  )
  expect_equal(
    unname(coef(glm_unweighted)),
    unname(coef(spark_unit_weights)),
    tolerance = 1e-5,
    scale = 1
  )
})

test_that("ml_logistic_regression.tbl_spark() takes both quoted and unquoted formulas", {
  sc <- testthat_spark_connection()
  # No local `iris_weighted` is defined in this test, so testthat_tbl() has to
  # find a table of that name on its own (presumably one registered by an
  # earlier test -- confirm).
  iris_weighted_tbl <- testthat_tbl("iris_weighted")

  # Same model specification, once as a character string ...
  from_string <- ml_logistic_regression(
    iris_weighted_tbl,
    formula = "versicolor ~ Sepal_Width + Petal_Length + Petal_Width"
  )
  # ... and once as a bare formula object.
  from_formula <- ml_logistic_regression(
    iris_weighted_tbl,
    formula = versicolor ~ Sepal_Width + Petal_Length + Petal_Width
  )

  expect_identical(from_string$formula, from_formula$formula)
})

test_that("ml_logistic_regression.tbl_spark() takes 'response' and 'features' columns instead of formula for backwards compatibility", {
  sc <- testthat_spark_connection()
  # No local `iris_weighted` is defined in this test, so testthat_tbl() has to
  # find a table of that name on its own (presumably one registered by an
  # earlier test -- confirm).
  iris_weighted_tbl <- testthat_tbl("iris_weighted")

  # Formula interface.
  via_formula <- ml_logistic_regression(
    iris_weighted_tbl,
    formula = "versicolor ~ Sepal_Width + Petal_Length + Petal_Width"
  )

  # Legacy interface: response column plus a vector of feature columns.
  feature_cols <- c("Sepal_Width", "Petal_Length", "Petal_Width")
  via_columns <- ml_logistic_regression(
    iris_weighted_tbl,
    response = "versicolor",
    features = feature_cols
  )

  # Both interfaces must end up with the same stored formula.
  expect_identical(via_formula$formula, via_columns$formula)
})

test_that("ml_logistic_regression.tbl_spark() warns when 'response' is a formula and 'features' is specified", {
  sc <- testthat_spark_connection()
  # No local `iris_weighted` is defined in this test, so testthat_tbl() has to
  # find a table of that name on its own (presumably one registered by an
  # earlier test -- confirm).
  iris_weighted_tbl <- testthat_tbl("iris_weighted")

  feature_cols <- c("Sepal_Width", "Petal_Length", "Petal_Width")
  expected_warning <- "'features' is ignored when a formula is specified"

  # A formula in `response` together with `features` must raise the warning.
  expect_warning(
    ml_logistic_regression(
      iris_weighted_tbl,
      response = versicolor ~ Sepal_Width + Petal_Length + Petal_Width,
      features = feature_cols
    ),
    expected_warning
  )
})

test_that("ml_logistic_regression.tbl_spark() errors if 'formula' is specified and either 'response' or 'features' is specified", {
  sc <- testthat_spark_connection()
  # No local `iris_weighted` is defined in this test, so testthat_tbl() has to
  # find a table of that name on its own (presumably one registered by an
  # earlier test -- confirm).
  iris_weighted_tbl <- testthat_tbl("iris_weighted")

  model_formula <- "versicolor ~ Sepal_Width + Petal_Length + Petal_Width"
  conflict_msg <-
    "only one of 'formula' or 'response'-'features' should be specified"

  # The formula is passed positionally; adding `response` must fail ...
  expect_error(
    ml_logistic_regression(
      iris_weighted_tbl,
      model_formula,
      response = "versicolor"
    ),
    conflict_msg
  )

  # ... and so must adding `features`.
  expect_error(
    ml_logistic_regression(
      iris_weighted_tbl,
      model_formula,
      features = c("Sepal_Width", "Petal_Length", "Petal_Width")
    ),
    conflict_msg
  )
})

test_that("we can fit multinomial models", {
  # The reference model comes from nnet::multinom(); skip instead of erroring
  # out when that optional package is not installed.
  skip_if_not_installed("nnet")

  sc <- testthat_spark_connection()
  test_requires_version("2.1.0", "multinomial models not supported < 2.1.0")

  # 200 rows: x counts 1..n, y cycles through four class labels.
  n <- 200
  data <- data.frame(
    x = seq_len(n),
    y = rep.int(letters[1:4], times = n / 4)
  )

  # fit multinomial model with R (suppress output for tests)
  capture.output(r <- nnet::multinom(y ~ x, data = data))

  # fit multinomial model with Spark
  # NOTE(review): copy_to() is called without an explicit table name, so the
  # local variable names `data` and `train` are kept unchanged in case the
  # remote table name is derived from them -- confirm before renaming.
  tbl <- copy_to(sc, data, overwrite = TRUE)
  s <- ml_logistic_regression(tbl, y ~ x)

  # validate that they generate conforming predictions
  # (it seems their parameterizations are different so
  # the underlying models aren't identical, but we should
  # at least confirm they produce conforming predictions)
  # sample(n) is a random permutation of 1..n: the same x values as the
  # training data, in shuffled order.
  train <- data.frame(x = sample(n))

  rp <- predict(r, train)
  sp <- predict(s, copy_to(sc, train, overwrite = TRUE))

  expect_equal(as.character(rp), as.character(sp))
})

test_that("weights column works for logistic regression", {
  sc <- testthat_spark_connection()
  set.seed(42)
  # iris plus: Poisson-derived weights shifted by one, a constant weight
  # column, and a 0/1 versicolor indicator. The local is only handed to
  # testthat_tbl() by name, so its name must match the string below.
  iris_weighted <- mutate(
    iris,
    weights = rpois(nrow(iris), 1) + 1,
    ones = rep(1, nrow(iris)),
    versicolor = ifelse(Species == "versicolor", 1L, 0L)
  )
  iris_weighted_tbl <- testthat_tbl("iris_weighted")

  # Reference formula for glm (dotted names) and the matching Spark feature
  # columns (underscored names) for the response/features interface.
  r_formula <- versicolor ~ Sepal.Width + Petal.Length + Petal.Width
  feature_cols <- c("Sepal_Width", "Petal_Length", "Petal_Width")

  # Case 1: observation weights taken from the `weights` column.
  glm_weighted <- glm(
    r_formula,
    family = binomial(logit),
    weights = weights,
    data = iris_weighted
  )
  spark_weighted <- ml_logistic_regression(
    iris_weighted_tbl,
    response = "versicolor",
    features = feature_cols,
    reg_param = 0L,
    weight_col = "weights"
  )
  expect_equal(
    unname(coef(glm_weighted)),
    unname(coef(spark_weighted)),
    tolerance = 1e-5,
    scale = 1
  )

  # Case 2: unweighted glm against Spark with the all-ones weight column.
  glm_unweighted <- glm(
    r_formula,
    family = binomial(logit),
    data = iris_weighted
  )
  spark_unit_weights <- ml_logistic_regression(
    iris_weighted_tbl,
    response = "versicolor",
    features = feature_cols,
    reg_param = 0L,
    weight_col = "ones"
  )
  expect_equal(
    unname(coef(glm_unweighted)),
    unname(coef(spark_unit_weights)),
    tolerance = 1e-5,
    scale = 1
  )
})

test_that("logistic regression bounds on coefficients", {
  sc <- testthat_spark_connection()
  test_requires_version("2.2.0", "coefficient bounds require 2.2+")

  iris_tbl <- testthat_tbl("iris")

  # 3 x 2 coefficient bounds and length-3 intercept bounds, all at +1 / -1.
  coef_upper <- matrix(1, nrow = 3, ncol = 2)
  coef_lower <- matrix(-1, nrow = 3, ncol = 2)
  intercept_upper <- rep(1, 3)
  intercept_lower <- rep(-1, 3)

  bounded_fit <- ml_logistic_regression(
    iris_tbl, Species ~ Petal_Width + Sepal_Length,
    upper_bounds_on_coefficients = coef_upper,
    lower_bounds_on_coefficients = coef_lower,
    upper_bounds_on_intercepts = intercept_upper,
    lower_bounds_on_intercepts = intercept_lower
  )

  # The extreme fitted values are expected to sit exactly on the bounds.
  fitted_coefs <- coef(bounded_fit)
  expect_equal(max(fitted_coefs), 1)
  expect_equal(min(fitted_coefs), -1)
})

# Top-level call placed after all tests in this file. The helper is defined
# outside this file; presumably it releases the Spark tables cached by the
# tests above -- confirm against its definition.
test_clear_cache()

